from os import replace
import pandas as pd
import torch
import spacy
import nltk,random
from nltk.corpus import stopwords
from nltk.stem.snowball import SnowballStemmer
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from scipy import spatial
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification
import stanza,pickle
import re,random,string
# from nltk.corpus import wordnet
import gensim.downloader as api

# ---------------------------------------------------------------------------
# Module-level resources. Everything below runs at import time and loads
# several pretrained models (GloVe, BERT NER, T5, BERT MLM, MiniLM, spaCy,
# Stanza), so importing this module is slow.
# ---------------------------------------------------------------------------

# GloVe Twitter embeddings (gensim-data "glove-twitter-25"); queried with
# most_similar() by gen_candidates_synonym() to propose replacement words.
syn_model = api.load('glove-twitter-25')  

# NOTE(review): not referenced by the functions below (they define their own
# local SEED); kept in case another module imports it.
seed = 1234


# BERT NER model wrapped in a token-classification pipeline `tok`.
# NOTE(review): the global name `tokenizer` is rebound further down to the T5
# paraphrase tokenizer. `tok` keeps the NER tokenizer it was built with, but
# any later lookup of `tokenizer` (e.g. in get_response) gets the T5 one.
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")

tok = pipeline("ner", model= ner_model, tokenizer=tokenizer)

# Replacement vocabularies used by gen_candidates_ner() (via random.choice /
# random.sample). Paths are relative to the current working directory.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted
# files. Presumably both hold lists of strings — confirm.
with open('../names.pkl','rb') as f:
    name_list = pickle.load(f)
with open('../city.pkl','rb') as f:
    city_list = pickle.load(f)
    
# Back-translation pipelines; gen_candidates_backtranslation() refers to the
# translator_* names defined by these lines.
# translator_en_ru = pipeline("translation", model="Helsinki-NLP/opus-mt-en-ru")
# translator_ru_en = pipeline("translation", model="Helsinki-NLP/opus-mt-ru-en")
# translator_en_de = pipeline("translation", model="Helsinki-NLP/opus-mt-en-de")
# translator_de_fr = pipeline("translation", model="Helsinki-NLP/opus-mt-de-fr")
# translator_fr_en = pipeline("translation", model="Helsinki-NLP/opus-mt-fr-en")

# T5 paraphrase tokenizer; this is the `tokenizer` get_response() uses.
tokenizer = AutoTokenizer.from_pretrained("ceshine/t5-paraphrase-quora-paws")

# spaCy English pipeline, used throughout the file for tokenization.
nlp = spacy.load('en_core_web_sm')
# T5 paraphrase model driven by get_response().
qtext_model = AutoModelForSeq2SeqLM.from_pretrained("ceshine/t5-paraphrase-quora-paws")

# BERT masked-LM pipeline used by gen_candidates_fill_mask().
unmasker = pipeline('fill-mask', model='bert-base-uncased')
# Stemmer used by make_dict() and gen_candidates_fill_mask().
snow_stemmer = SnowballStemmer(language='english')

# Stop-word set: NLTK English stop words + the placeholders number0..number8
# + every single ASCII punctuation character. The fill-mask and synonym
# generators skip tokens in this set when choosing words to replace.
stop_words = stopwords.words('english')
punct = string.punctuation
extend_list = ["number0","number1","number2","number3","number4","number5","number6","number7","number8"]
stop_words.extend(extend_list)
# stop_words.extend(["?",",",".","How","What","many","more","year","Why","When","one","couple","could"])
# Split the punctuation string into a list of one-character strings.
punct = [w for w in punct]
stop_words.extend(punct)
stop_words = set(stop_words)
# print("stop_words: ",stop_words)

# Sentence-embedding model used by candidate_selector() for cosine similarity.
sim_model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2')
# torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stanza pipeline with dependency parsing; used by add_group_nums().
nlp_stanza = stanza.Pipeline(lang='en', processors='tokenize, pos, lemma, depparse')

def add_group_nums(sent):
    """Return token positions that belong with the numbers in *sent*.

    *sent* is normalised (hyphens removed, "mrs." -> "mrs"), parsed with the
    Stanza pipeline ``nlp_stanza`` and tokenized with ``nltk.word_tokenize``.
    Returned positions index the NLTK token list. The code assumes Stanza's
    word sequence lines up one-to-one with the NLTK tokens.
    NOTE(review): that is not guaranteed (e.g. multi-word tokens); confirm on
    real data.

    A position is selected when it is:
      * a numeric token, or its immediate left/right neighbour unless that
        neighbour is ',', '.' or ';';
      * the head of a number attached as ``nummod``/``nmod``;
      * one of the rate words 'each', 'every', 'per';
      * an ``amod`` dependent of one of those head nouns;
      * the head of an ``obj``/``nsubj`` word whose text is one of those nouns;
      * the 4th-, 3rd- or 2nd-to-last token, unless it is ',', '.' or ';'.

    Args:
        sent: Input text; may contain several sentences.

    Returns:
        List of unique token indices, in no particular order.
    """
    separators = (',', '.', ';')

    sent = re.sub(r"-", r"", sent)
    # Escaped dot: an unescaped "mrs." matches any character after "mrs",
    # so "mrs smith" would become "mrssmith".
    sent = re.sub(r"mrs\.", r"mrs", sent)
    sent_nums = re.findall(r'\d*\.?\d+', sent)
    doc = nlp_stanza(sent)
    tokens = nltk.word_tokenize(sent)
    n_tokens = len(tokens)

    final_ids = []
    assoc_nouns = []

    def add_if_content(pos):
        # Keep *pos* only if it lies inside the token list and is not a separator.
        if 0 <= pos < n_tokens and tokens[pos] not in separators:
            final_ids.append(pos)

    # Pass 1: numbers, their neighbours and head nouns, and rate words.
    # `offset` turns Stanza's per-sentence, 1-based word ids into positions in
    # the whole-text token list.
    offset = 0
    for s in doc.sentences:
        for word in s.words:
            pos = offset + word.id - 1
            if word.text in sent_nums:
                final_ids.append(pos)
                add_if_content(pos - 1)
                add_if_content(pos + 1)
                # UD relation is "nmod" (it was misspelled "nmode" and never matched).
                if word.deprel in ('nummod', 'nmod'):
                    assoc_nouns.append(s.words[word.head - 1].text)
                    final_ids.append(offset + word.head - 1)
            if word.text in ('each', 'every', 'per'):
                final_ids.append(pos)
        offset += s.words[-1].id if s.words else 0

    # Pass 2: adjectives of the associated nouns and the heads governing
    # those nouns. Positions are offset exactly as in pass 1; without the
    # offset they were sentence-local and wrong after the first sentence.
    offset = 0
    for s in doc.sentences:
        for word in s.words:
            if word.deprel == 'amod' and s.words[word.head - 1].text in assoc_nouns:
                final_ids.append(offset + word.id - 1)
            if word.text in assoc_nouns and word.deprel in ('obj', 'nsubj'):
                final_ids.append(offset + word.head - 1)
        offset += s.words[-1].id if s.words else 0

    # Also keep the 4th-, 3rd- and 2nd-to-last tokens. The final token is
    # skipped; presumably it is the closing '?'.
    for pos in (n_tokens - 4, n_tokens - 3, n_tokens - 2):
        add_if_content(pos)

    return list(set(final_ids))
# orig_text = "Rachel was organizing her book case making sure each of the shelves had exactly number0 books on it . If she had number1 shelves of mystery books and number2 shelves of picture books , how many books did she have total ?"
# bt_text = "Rachel arranged the bookwork to make sure that each of the shelves had exactly number0 books on it . If she had number1 shelves of mystery books and number2 shelves of painting books ,How many books did she have ?"
# given_that_text  = "How many books did she have total given that rachel was organizing her book case making sure each of the shelves had exactly number0 books on it and she had number1 shelves of mystery books and number2 shelves of picture books ?"
# group_nums = add_group_nums(given_that_text)
# print(group_nums)

# [13, 14, 15, 21, 22, 23, 27, 28, 29, 38, 39, 40]

def get_response(input_text, num_return_sequences, num_beams):
    """Paraphrase *input_text* with the T5 paraphrase model using beam search.

    Args:
        input_text: Sentence to paraphrase; truncated to 60 tokens.
        num_return_sequences: Number of beams to return. The transformers
            ``generate`` API requires this to be <= *num_beams*.
        num_beams: Beam width.

    Returns:
        List of decoded paraphrases with special tokens removed.
    """
    encoded = tokenizer(
        [input_text],
        truncation=True,
        padding='longest',
        max_length=60,
        return_tensors="pt",
    )
    # NOTE(review): temperature only affects sampling. No do_sample is passed
    # here, so it is presumably ignored unless the model's generation config
    # enables sampling.
    generated = qtext_model.generate(
        **encoded,
        max_length=60,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        temperature=1.5,
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)


def make_dict(tokens):
    """Map each token's Snowball stem to the positions where it occurs.

    Args:
        tokens: Sequence of token strings.

    Returns:
        Plain dict ``{stem: [index, ...]}``. Indices are in ascending order
        and keys appear in order of first occurrence.
    """
    positions = {}
    for position, token in enumerate(tokens):
        positions.setdefault(snow_stemmer.stem(token), []).append(position)
    return positions

def word_replacement(unmasked_sent):
    """Return the ``token_str`` of the first (top-ranked) fill-mask prediction.

    Args:
        unmasked_sent: Indexable sequence of prediction dicts, e.g. the output
            of the ``unmasker`` pipeline for a single ``[MASK]``.
    """
    top_prediction = unmasked_sent[0]
    return top_prediction['token_str']


def candidate_selector(orignal_sent, candidates):
    """Score each candidate by cosine similarity to the original sentence.

    Args:
        orignal_sent: Reference sentence.
        candidates: List of candidate sentences.

    Returns:
        List of floats in candidate order. Each is ``1 - cosine distance``
        between the ``sim_model`` embedding of *orignal_sent* and that of the
        candidate.
    """
    candidate_vectors = sim_model.encode(candidates)
    reference = sim_model.encode(orignal_sent)
    scores = []
    for vector in candidate_vectors:
        scores.append(1 - spatial.distance.cosine(reference, vector))
    return scores

def text_split(text):
    """Split *text* into a list of sentences using NLTK's sentence tokenizer."""
    sentences = nltk.tokenize.sent_tokenize(text)
    return sentences

def pre_process_text(text):
    """Replace ``number<N>`` placeholders with random decimal strings.

    Presumably this lets translation models see ordinary numbers instead of
    placeholder tokens. post_process() maps the values back afterwards.

    Args:
        text: Text that may contain placeholders ``number0``, ``number1``, ...

    Returns:
        ``(pre_text, num_replacement, word_idx)`` where

        * ``pre_text``: the space-joined spaCy tokens of *text*, with every
          placeholder token replaced.
        * ``num_replacement``: ``{placeholder: value}``. Values have the form
          ``str(round(random.uniform(0, 1), 2))`` and are pairwise distinct
          (up to the 101 possible values), so the mapping can be inverted.
        * ``word_idx``: ``{token: [positions]}`` for the original tokens.
    """
    tokens = [token.text for token in nlp(text)]
    word_idx = {}
    for position, word in enumerate(tokens):
        word_idx.setdefault(word, []).append(position)

    # \d+ so that number10, number11, ... match whole (\d stopped at "number1").
    # dict.fromkeys de-duplicates while keeping first-seen order.
    placeholders = list(dict.fromkeys(re.findall(r'number\d+', text)))

    num_replacement = {}
    used = set()
    for placeholder in placeholders:
        value = str(round(random.uniform(0, 1), 2))
        # Re-draw on collision: post_process() can only restore a placeholder
        # whose value is unique. Only 101 values exist, hence the bound.
        while value in used and len(used) < 101:
            value = str(round(random.uniform(0, 1), 2))
        used.add(value)
        num_replacement[placeholder] = value

    for placeholder, value in num_replacement.items():
        # Only placeholders that are standalone tokens are replaced; others
        # (e.g. glued to punctuation by the tokenizer) are left untouched.
        for position in word_idx.get(placeholder, ()):
            tokens[position] = value

    pre_text = " ".join(tokens)
    return pre_text, num_replacement, word_idx

def post_process(text, num_replacement):
    """Restore placeholders in text produced from pre_process_text() output.

    Args:
        text: Text (e.g. a back-translation) containing replacement values.
        num_replacement: ``{placeholder: value}`` mapping from
            pre_process_text().

    Returns:
        The spaCy tokens of *text* as a list. Every token equal to a value is
        swapped back to its placeholder; mappings are applied one after
        another in dict order.
    """
    restored = [token.text for token in nlp(text)]
    for placeholder, value in num_replacement.items():
        restored = [placeholder if piece == value else piece for piece in restored]
    return restored

def gen_candidates_fill_mask(df,i):
    """Generate augmented questions by re-filling masked words with BERT.

    For row *i* of *df*:

    1. ``Ques_Statement`` is paraphrased with T5.
    2. Each paraphrase is re-tokenized; paraphrases identical to the
       statement are dropped, and the rest are appended to ``Body``.
    3. For each candidate, three variants are built. Each variant starts
       from the unmodified candidate and masks up to three non-stop-word
       stems, one after another. Every occurrence of a masked stem is
       replaced with the top ``unmasker`` prediction.

    Tokens inside NLTK PERSON chunks are never chosen for masking.

    Args:
        df: DataFrame with ``Question``, ``Equation``, ``Body`` and
            ``Ques_Statement`` columns.
        i: Row label.

    Returns:
        ``(cand_sent, [str(equation)], sim_score)``:

        * ``cand_sent``: the de-duplicated candidates followed by the
          original question.
        * ``sim_score``: each candidate's similarity to the original question.
    """
    question = df.Question[i]
    orig_eqn = df.Equation[i]
    sent_qtext = df.Ques_Statement[i]
    body = df.Body[i]

    # Re-tokenize paraphrases with spaCy (presumably Ques_Statement is stored
    # space-tokenized the same way) and drop verbatim copies. A filtered copy
    # avoids the element skipping caused by list.remove() during iteration.
    q_par = [" ".join(token.text for token in nlp(c)) for c in get_response(sent_qtext, 5, 10)]
    q_par = [c for c in q_par if c != sent_qtext]

    # Append each paraphrase to the body, adding a '.' separator if needed.
    # If the body is missing (not a string, e.g. NaN) or tokenizes to
    # nothing, use the bare paraphrases.
    body_tokens = [token.text for token in nlp(body)] if isinstance(body, str) else []
    if body_tokens:
        if body_tokens[-1] != ".":
            body_tokens.append(".")
        candidate = [" ".join(body_tokens + [token.text for token in nlp(q)]) for q in q_par]
    else:
        candidate = list(q_par)

    cand_sent = []
    for c in candidate:
        tokens = [token.text for token in nlp(c)]

        # Every token of every PERSON chunk is excluded from masking. Removing
        # the joined chunk text failed with ValueError for multi-word names.
        person_tokens = set()
        for chunk in nltk.ne_chunk(nltk.pos_tag(tokens)):
            if hasattr(chunk, 'label') and chunk.label() == 'PERSON':
                person_tokens.update(leaf[0] for leaf in chunk)

        tok_id = make_dict(tokens)
        # Sorted so the random choice does not depend on set iteration order
        # (which varies with PYTHONHASHSEED).
        stems = sorted({
            snow_stemmer.stem(w)
            for w in tokens
            if w not in person_tokens and w.lower() not in stop_words
        })

        for _ in range(3):
            # Each variant starts again from the unmodified candidate.
            new_tokens = list(tokens)
            for stem in random.sample(stems, min(3, len(stems))):
                index_list = tok_id[stem]
                new_tokens[index_list[0]] = '[MASK]'
                replace_word = word_replacement(unmasker(" ".join(new_tokens)))
                for idx in index_list:
                    new_tokens[idx] = replace_word
            cand_sent.append(" ".join(new_tokens))

    # De-duplicate while keeping a stable (first-seen) order.
    cand_sent = list(dict.fromkeys(cand_sent))
    cand_sent.append(question)
    sim_score = candidate_selector(question, cand_sent)
    return cand_sent, [str(orig_eqn)], sim_score


def gen_candidates_synonym(df,i):
    """Generate augmented questions by swapping words for GloVe neighbours.

    The global ``random`` module is seeded with ``1234 + i``. Each T5
    paraphrase of ``Ques_Statement`` is appended to ``Body``. For each
    candidate, up to three distinct non-stop-word tokens are sampled, and
    their three nearest ``syn_model`` neighbours are looked up; tokens that
    are not in the vocabulary are skipped. Variant ``m`` (m = 0..2) replaces
    every occurrence of each sampled token with its m-th neighbour.

    Args:
        df: DataFrame with ``Question``, ``Equation``, ``Body`` and
            ``Ques_Statement`` columns.
        i: Integer row label; also used in the seed.

    Returns:
        ``(cand_sent, [str(equation)], sim_score)``:

        * ``cand_sent``: the de-duplicated candidates followed by the
          original question.
        * ``sim_score``: each candidate's similarity to the original question.
    """
    SEED = 1234
    random.seed(SEED+i)
    topn = 3  # neighbours per word, and therefore variants per candidate

    question = df.Question[i]
    orig_eqn = df.Equation[i]
    sent_qtext = df.Ques_Statement[i]
    body = df.Body[i]
    q_par = get_response(sent_qtext, 5, 10)

    # Append each paraphrase to the body, adding a '.' separator if needed.
    # If the body is missing (not a string, e.g. NaN) or tokenizes to
    # nothing, use the bare paraphrases.
    body_tokens = [token.text for token in nlp(body)] if isinstance(body, str) else []
    if body_tokens:
        if body_tokens[-1] != ".":
            body_tokens.append(".")
        candidate = [" ".join(body_tokens + [token.text for token in nlp(q)]) for q in q_par]
    else:
        candidate = list(q_par)

    cand_sent = []
    for c in candidate:
        tokens = [token.text for token in nlp(c)]
        # Sorted so the seeded sample is reproducible (set iteration order
        # varies with PYTHONHASHSEED).
        content_words = sorted({w for w in tokens if w.lower() not in stop_words})
        chosen = random.sample(content_words, min(3, len(content_words)))

        words, synonyms = [], []
        for w in chosen:
            try:
                similar = syn_model.most_similar(w, topn=topn)
            except KeyError:
                continue  # out-of-vocabulary word: leave it unchanged
            words.append(w)
            synonyms.append([p[0] for p in similar])
        if not words:
            continue

        # Variant m uses the m-th nearest neighbour of every chosen word.
        # Looping over the number of words instead would use only the first
        # neighbour(s) whenever fewer than `topn` words were chosen.
        for m in range(topn):
            replacement = {w: syns[m] for w, syns in zip(words, synonyms) if m < len(syns)}
            cand_sent.append(" ".join(replacement.get(t, t) for t in tokens))

    # De-duplicate while keeping a stable (first-seen) order.
    cand_sent = list(dict.fromkeys(cand_sent))
    cand_sent.append(question)
    sim_score = candidate_selector(question, cand_sent)
    return cand_sent, [str(orig_eqn)], sim_score



def _to_clause(sentence):
    """Turn one body sentence into a clause for the reordered question.

    Whitespace-only tokens, one leading "if" (any case) and one trailing
    '.', ',' or '?' are removed, and the first remaining token is lowercased.

    Returns:
        The space-joined clause, or "" if no tokens remain.
    """
    tokens = [token.text for token in nlp(sentence) if not token.text.isspace()]
    if tokens and tokens[0].lower() == "if":
        tokens = tokens[1:]
    if tokens and tokens[-1] in (".", ",", "?"):
        tokens = tokens[:-1]
    if not tokens:
        return ""
    tokens[0] = tokens[0].lower()
    return " ".join(tokens)


def gen_candidates_reordering(df,i):
    """Build "If ... then ..." and "... given that ..." rewrites of row *i*.

    ``Body`` is split into sentences, and each becomes a clause via
    _to_clause(). For every T5 paraphrase of ``Ques_Statement`` two
    candidates are produced:

    * ``"If " + " and ".join(clauses) + " then " + <lowercased paraphrase>``
    * ``<paraphrase, closing punctuation dropped, first token capitalised>
      + " given that " + " and ".join(clauses) + " ."``

    Args:
        df: DataFrame with ``Question``, ``Equation``, ``Body`` and
            ``Ques_Statement`` columns.
        i: Row label.

    Returns:
        ``(if_cand_sent, given_cand_sent, [str(equation)], if_sim_score,
        given_sim_score)``. Each candidate list is de-duplicated and ends with
        the original question; each score list is aligned with its candidate
        list.
    """
    question = df.Question[i]
    orig_eqn = df.Equation[i]
    q_par = get_response(df.Ques_Statement[i], 5, 10)

    # The clauses do not depend on the paraphrase, so build them once. Empty
    # clauses are dropped (indexing their first token raised IndexError).
    # NOTE: the old per-paraphrase loop compared the *row* index `i` with the
    # clause count and then overwrote the clause's last word with "?". That
    # branch is removed; the "given that" form ends with " ." regardless.
    clauses = [clause for clause in map(_to_clause, text_split(df.Body[i])) if clause]
    body_part = " and ".join(clauses)

    if_cand_sent = []
    given_cand_sent = []
    for paraphrase in q_par:
        q_tokens = [token.text.lower() for token in nlp(paraphrase)]
        if_cand_sent.append("If " + body_part + " then " + " ".join(q_tokens))

        # Drop the closing punctuation only when there is one; removing the
        # last token unconditionally lost a word from unpunctuated paraphrases.
        given_tokens = list(q_tokens)
        if given_tokens and given_tokens[-1] in (".", "?", "!"):
            given_tokens = given_tokens[:-1]
        if given_tokens:
            given_tokens[0] = given_tokens[0].capitalize()
        given_cand_sent.append(" ".join(given_tokens) + " given that " + body_part + " .")

    # De-duplicate while keeping a stable (first-seen) order.
    if_cand_sent = list(dict.fromkeys(if_cand_sent))
    if_cand_sent.append(question)
    given_cand_sent = list(dict.fromkeys(given_cand_sent))
    given_cand_sent.append(question)
    if_sim_score = candidate_selector(question, if_cand_sent)
    given_sim_score = candidate_selector(question, given_cand_sent)
    return if_cand_sent, given_cand_sent, [str(orig_eqn)], if_sim_score, given_sim_score


def _replace_phrase(tokens, phrase_tokens, replacement):
    """Return a copy of *tokens* where each run equal to *phrase_tokens*
    (a list) is replaced by the single token *replacement*."""
    width = len(phrase_tokens)
    if width == 0:
        return list(tokens)
    out = []
    pos = 0
    while pos < len(tokens):
        if tokens[pos:pos + width] == phrase_tokens:
            out.append(replacement)
            pos += width
        else:
            out.append(tokens[pos])
            pos += 1
    return out


def gen_candidates_ner(df,i):
    """Generate augmented questions by swapping person and place names.

    The global ``random`` module is seeded with ``1234 + i``. Each T5
    paraphrase of ``Ques_Statement`` is appended to ``Body``. In every
    candidate, NLTK's ``ne_chunk`` finds PERSON and GPE entities; places
    that are also names are dropped. Candidates with at least one entity
    get three variants. In each variant, every entity (including multi-word
    ones) is replaced by a random entry of ``name_list`` or ``city_list``.

    Args:
        df: DataFrame with ``Question``, ``Equation``, ``Body`` and
            ``Ques_Statement`` columns.
        i: Integer row label; also used in the seed.

    Returns:
        ``(cand_sent, [str(equation)], sim_score)``:

        * ``cand_sent``: the de-duplicated candidates followed by the
          original question.
        * ``sim_score``: each candidate's similarity to the original question.
    """
    SEED = 1234
    random.seed(SEED + i)
    question = df.Question[i]
    orig_eqn = df.Equation[i]
    sent_qtext = df.Ques_Statement[i]
    body = df.Body[i]
    q_par = get_response(sent_qtext, 5, 10)

    # Append each paraphrase to the body, adding a '.' separator if needed.
    # If the body is missing (not a string, e.g. NaN) or tokenizes to
    # nothing, use the bare paraphrases.
    body_tokens = [token.text for token in nlp(body)] if isinstance(body, str) else []
    if body_tokens:
        if body_tokens[-1] != ".":
            body_tokens.append(".")
        candidate = [" ".join(body_tokens + [token.text for token in nlp(q)]) for q in q_par]
    else:
        candidate = list(q_par)

    cand_sent = []
    for text in candidate:
        tokens = [token.text for token in nlp(text)]

        # Entities are kept as token tuples so multi-word names can be matched.
        names, places = set(), set()
        for sentence in nltk.sent_tokenize(text):
            for chunk in nltk.ne_chunk(nltk.pos_tag(nltk.word_tokenize(sentence))):
                if not hasattr(chunk, 'label'):
                    continue
                entity = tuple(leaf[0] for leaf in chunk)
                if chunk.label() == 'PERSON':
                    names.add(entity)
                elif chunk.label() == 'GPE':
                    places.add(entity)
        places -= names
        if not (names or places):
            continue

        for _ in range(3):
            new_tokens = list(tokens)
            # Sorted so the seeded random choices are reproducible.
            for name in sorted(names):
                new_tokens = _replace_phrase(new_tokens, list(name), random.choice(name_list))
            for place in sorted(places):
                # Pick one city. random.sample(city_list, 3) returned a list,
                # and " ".join() then raised TypeError.
                new_tokens = _replace_phrase(new_tokens, list(place), random.choice(city_list))
            # NOTE(review): assumes name_list / city_list entries are strings.
            cand_sent.append(" ".join(new_tokens))

    # De-duplicate while keeping a stable (first-seen) order.
    cand_sent = list(dict.fromkeys(cand_sent))
    cand_sent.append(question)
    sim_score = candidate_selector(question, cand_sent)
    return cand_sent, [str(orig_eqn)], sim_score
        

# Hugging Face model ids for the back-translation pipelines, keyed by the
# global name each pipeline is stored under.
_BACKTRANSLATION_MODELS = {
    "translator_en_ru": "Helsinki-NLP/opus-mt-en-ru",
    "translator_ru_en": "Helsinki-NLP/opus-mt-ru-en",
    "translator_en_de": "Helsinki-NLP/opus-mt-en-de",
    "translator_de_fr": "Helsinki-NLP/opus-mt-de-fr",
    "translator_fr_en": "Helsinki-NLP/opus-mt-fr-en",
}


def _get_translator(name):
    """Return the global translation pipeline *name*, creating it on first use.

    An existing module-level pipeline with that name is reused. Otherwise
    the model from _BACKTRANSLATION_MODELS is loaded once and cached under
    the same global name, so importing the module does not load these models.
    """
    translator = globals().get(name)
    if translator is None:
        translator = pipeline("translation", model=_BACKTRANSLATION_MODELS[name])
        globals()[name] = translator
    return translator


def gen_candidates_backtranslation(df,i):
    """Generate augmented questions by round-trip translation.

    Each T5 paraphrase of ``Ques_Statement`` is appended to ``Body``.
    Placeholders are swapped for random numbers with pre_process_text(), and
    each sentence is translated en->de->fr->en and en->ru->en. The
    placeholders are then restored with post_process().

    Args:
        df: DataFrame with ``Question``, ``Equation``, ``Body`` and
            ``Ques_Statement`` columns.
        i: Row label.

    Returns:
        ``(cand_sent, [str(equation)], sim_score)``:

        * ``cand_sent``: the de-duplicated candidates followed by the
          original question.
        * ``sim_score``: each candidate's similarity to the original question.
    """
    question = df.Question[i]
    orig_eqn = df.Equation[i]
    sent_qtext = df.Ques_Statement[i]
    body = df.Body[i]
    q_par = get_response(sent_qtext, 5, 10)

    # Append each paraphrase to the body, adding a '.' separator if needed.
    # If the body is missing (not a string, e.g. NaN) or tokenizes to
    # nothing, use the bare paraphrases.
    body_tokens = [token.text for token in nlp(body)] if isinstance(body, str) else []
    if body_tokens:
        if body_tokens[-1] != ".":
            body_tokens.append(".")
        candidate = [" ".join(body_tokens + [token.text for token in nlp(q)]) for q in q_par]
    else:
        candidate = list(q_par)

    cand_sent = []
    for text in candidate:
        en_de = _get_translator("translator_en_de")
        de_fr = _get_translator("translator_de_fr")
        fr_en = _get_translator("translator_fr_en")
        en_ru = _get_translator("translator_en_ru")
        ru_en = _get_translator("translator_ru_en")

        pre_text, num_replacement, _ = pre_process_text(text)
        via_fr, via_ru = [], []
        for piece in text_split(pre_text):
            de_text = en_de(piece)[0]['translation_text']
            fr_text = de_fr(de_text)[0]['translation_text']
            via_fr.append(fr_en(fr_text)[0]['translation_text'])
            ru_text = en_ru(piece)[0]['translation_text']
            via_ru.append(ru_en(ru_text)[0]['translation_text'])

        # NOTE(review): a placeholder is restored only if the translation
        # reproduces its number token exactly (e.g. "0,37" would not match).
        cand_sent.append(" ".join(post_process(" ".join(via_ru), num_replacement)))
        cand_sent.append(" ".join(post_process(" ".join(via_fr), num_replacement)))

    # De-duplicate while keeping a stable (first-seen) order.
    cand_sent = list(dict.fromkeys(cand_sent))
    cand_sent.append(question)
    sim_score = candidate_selector(question, cand_sent)
    return cand_sent, [str(orig_eqn)], sim_score




# df = pd.read_csv("../mawps/MaWPS.csv")
# cand_sent,_,sim_score = gen_candidates_fill_mask(df,21)
# print(cand_sent,len(cand_sent))
# print("Sim-score: ",sim_score,len(sim_score))
    # cand_sent = []
    # count = 0
    # for c in candidate:
    #     count += 1
    #     sent = nlp(c)
    #     tokens = [token.text for token in sent]
    #     tok_id = make_dict(c)
    #     filtered_sentence = set([snow_stemmer.stem(w) for w in tokens if not w.lower() in stop_words])
    #     # print("filtered_sent :",filtered_sentence)
    #     for k in range(3):
    #         replace_tok = random.sample(list(filtered_sentence),min(3,len(filtered_sentence)))
    #         for w in replace_tok:
    #             index_list = tok_id[w]
    #             tokens[index_list[0]] = '[MASK]'
    #             masked_sent = " ".join(tokens)
    #             # print("masked: ",masked_sent)
    #             unmasked_sent = unmasker(masked_sent)
    #             replace_word = word_replacement(unmasked_sent)
    #             # print(f"for candidate {count}, {w} : {replace_word} ")
    #             for idx in index_list:
    #                 tokens[idx] = replace_word
    #         new_sent = " ".join(tokens)
    #         cand_sent.append(new_sent)
    # cand_sent = list(set(cand_sent))
    # cand_sent.append(question[i])
    










        
